Exclude small-k and small-n Matmul nodes from Int8 quantization#1256
nv-samcheng wants to merge 2 commits into NVIDIA:main
Conversation
No actionable comments were generated in the recent review.

📝 Walkthrough

Extended the MatMul exclusion logic to also exclude "small-gemm" MatMul nodes when the inferred N or K dimension is below 16. Added helper `_get_inp_b_k_dim`.
Actionable comments posted: 2
🧹 Nitpick comments (1)
tests/unit/onnx/quantization/test_graph_utils.py (1)
119-182: Add targeted tests for `Gemm(transB=1)` and inference-based exclusion.

Nice coverage for MatMul shape inference. Please add one case validating K extraction when op="Gemm" with transB=1, plus one test for `_exclude_matmuls_by_shape_inference` (shared `inp_b` Variable case) to lock in the new runtime-output extension path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit/onnx/quantization/test_graph_utils.py` around lines 119 - 182, Add two unit tests in tests/unit/onnx/quantization/test_graph_utils.py: one that constructs a Gemm model with op="Gemm" and attribute transB=1 and asserts _get_inp_b_k_dim on its node returns the correct K (e.g., when B is constant with shape [..., K, N] transposed), and a second test that exercises _exclude_matmuls_by_shape_inference where multiple MatMul/Gemm nodes share the same inp_b Variable (use calibration_shapes only for "A" and provide an output_map or runtime-output scenario so the code path that reads K from runtime-output is used) and assert the expected node id is excluded; reference helpers _make_matmul_model, _get_matmul_nodes, _get_inp_b_k_dim, and _exclude_matmuls_by_shape_inference to locate relevant setup and ensure names/ids match existing tests (e.g., "MatMul_0").
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/onnx/quantization/graph_utils.py`:
- Around line 1235-1261: The _get_inp_b_k_dim function currently always reads K
from axis -2 which is wrong for Gemm when transB=1; update _get_inp_b_k_dim to
detect transB (default 0 for MatMul) from the node (check for attribute "transB"
on matmul_node) and compute k_axis = -1 if transB > 0 else -2, then use k_axis
when indexing into inp_b.values.shape, inp_b_info.type.tensor_type.shape.dim,
and output_map[inp_b.name].shape so all three fallback paths respect
transposition; also add unit tests that cover Gemm nodes with transB=1 to
prevent regressions.
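The fix described above can be sketched as follows; `get_k_axis` and its arguments are illustrative names, not the actual modelopt code:

```python
# Sketch of the suggested fix: choose the K axis of input B from the node's
# transB attribute (MatMul has no transB, so it effectively defaults to 0).

def get_k_axis(op_type: str, attrs: dict) -> int:
    """Return the axis of B that holds K: -2 normally, -1 when transB is set."""
    trans_b = attrs.get("transB", 0) if op_type == "Gemm" else 0
    return -1 if trans_b else -2

# MatMul: B is [..., K, N], so K is at axis -2.
# Gemm with transB=1: B is stored as [N, K], so K is at axis -1.
b_shape = (64, 16)  # Gemm weight stored transposed: N=64, K=16
k = b_shape[get_k_axis("Gemm", {"transB": 1})]
```

All three fallback paths (constant values, value-info shape, runtime output shape) would index with the same `k_axis` so they agree on where K lives.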
- Around line 1343-1348: The code adds matmul outputs and second-input Variable
names to model.graph.output without deduplication, which can create duplicate
output names; update the logic (in the block handling matmul_nodes / uses of
matmul_node.outputs[0].name and matmul_node.inputs[1].name) to track
already-added output names (e.g., a set of names) and only call
model.graph.output.extend with onnx.ValueInfoProto for a name if it is not
already present in that set (and add it to the set after extending), ensuring
you still skip Constants by checking isinstance(matmul_node.inputs[1],
Variable).
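A minimal sketch of the deduplication suggested above; the node dicts and helper names are hypothetical stand-ins for the graph_utils objects:

```python
# Track already-added names in a set so the same tensor is never exposed as a
# graph output twice, while still skipping Constant second inputs.

def collect_extra_outputs(matmul_nodes, is_constant):
    """Gather tensor names to expose as graph outputs, without duplicates."""
    seen = set()
    extra = []
    for node in matmul_nodes:
        candidates = [node["output"]]
        if not is_constant(node["inp_b"]):  # mirrors the isinstance(..., Variable) check
            candidates.append(node["inp_b"])
        for name in candidates:
            if name not in seen:
                seen.add(name)
                extra.append(name)
    return extra
```

With two MatMuls sharing one `inp_b` Variable, the shared name is emitted only once instead of twice.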
---
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 3a5d8843-1a90-424d-a931-a88d63dc0fa0
📒 Files selected for processing (2)
- modelopt/onnx/quantization/graph_utils.py
- tests/unit/onnx/quantization/test_graph_utils.py
What does this PR do?
Exclude small-dimension MatMul nodes from INT8 quantization. MatMuls with N or K < 16 cannot efficiently use INT8, causing performance regressions.
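The rule above can be expressed as a simple predicate; the threshold of 16 comes from this PR's description, and the function name is a hypothetical illustration, not the PR's code:

```python
# Illustrative exclusion predicate: a MatMul is skipped for INT8 quantization
# when either inferred dimension is known and below the threshold.
SMALL_GEMM_THRESHOLD = 16

def should_exclude_matmul(k, n):
    """Return True when N or K is known (not None) and below 16."""
    return any(d is not None and d < SMALL_GEMM_THRESHOLD for d in (k, n))
```

Dimensions that cannot be inferred (`None`) do not trigger exclusion, so only provably small GEMMs are kept in higher precision.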
Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: ✅ / ❌ / N/A

Additional Information
Summary by CodeRabbit

Bug Fixes
- Excluded MatMul nodes with inferred N or K below 16 from INT8 quantization to avoid performance regressions.

Tests
- Added unit tests covering the MatMul shape-inference exclusion logic.